import torch
import argparse
import numpy as np
from matplotlib import pyplot as plt
import os
import json
from tqdm import tqdm
from lama_inpaint import inpaint_img_with_lama_loaded, load_lama_model
from utils import load_img_to_array, save_array_to_img, dilate_mask, \
    show_mask
from collections import defaultdict


def main(
            annotation_file='../data/gqa/val_balanced_gqa_coco_captions_region_captions_scene_graphs.jsonl',
            image_folder='../data/vg/',
            mask_folder='../data/vg_samples/remove_anything/gsam_masks_text_box',
            debug=False):
    """Run LaMa object removal for every annotated image/object-name group.

    For each record of the JSONL annotation file, objects of the scene graph
    are grouped by name, the pre-computed mask file of each group is looked
    up in ``mask_folder`` and, when present, passed to :func:`remove`.

    Args:
        annotation_file: Path to a JSONL file; each non-blank line is a JSON
            object with a ``"vg_id"`` and a ``"scene_graph"`` mapping of
            object id -> ``{"name", "relations", "attributes", ...}``.
        image_folder: Directory containing ``<vg_id>.jpg`` images.
        mask_folder: Directory containing
            ``<vg_id>.<id1>-<id2>-..._mask.npy`` files. The output directory
            is derived from it by replacing ``"gsam_masks"`` with ``"lama"``
            (if the substring is absent, outputs land in ``mask_folder``).
        debug: When true, only the first 100 records are processed.

    Raises:
        FileNotFoundError: If the image of a record does not exist.
    """
    lama_config = "lama/configs/prediction/default.yaml"
    lama_ckpt = "./pretrained_models/big-lama"
    output_folder = mask_folder.replace("gsam_masks", "lama")
    model, predict_config = load_lama_model(lama_config, lama_ckpt)

    os.makedirs(output_folder, exist_ok=True)
    # Use a context manager so the file handle is closed, and skip blank
    # lines (e.g. a trailing newline) that would make json.loads fail.
    with open(annotation_file, "r") as f:
        data = [json.loads(line) for line in f if line.strip()]
    if debug:
        data = data[:100]

    for d in tqdm(data):
        file_id = d["vg_id"]
        image_path = os.path.join(image_folder, f'{file_id}.jpg')
        # Explicit raise instead of `assert`: asserts vanish under `python -O`
        # and carry no message about which image is missing.
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file {image_path} does not exist")

        # Group object ids by object name; objects lacking relations or
        # attributes, and the "background" pseudo-object, are ignored.
        name2obj = defaultdict(list)
        for obj_id, item in d["scene_graph"].items():
            if not item["relations"] or not item["attributes"]:
                continue
            name = item["name"]
            if name == "background":
                continue
            name2obj[name].append(obj_id)

        for name, obj_ids in tqdm(name2obj.items(), desc=f"Processing {file_id}"):
            # Sorted ids joined by "-" form the mask file's base name.
            joined_ids = "-".join(sorted(obj_ids))
            output_filename = f'{file_id}.{joined_ids}'
            mask_file = os.path.join(mask_folder, f'{output_filename}_mask.npy')
            if not os.path.exists(mask_file):
                print(f"Mask file {mask_file} does not exist")
                continue
            remove(model, predict_config, image_path, mask_file, output_filename, output_folder)


def remove(model, predict_config, img_file, mask_file, output_filename, output_dir, dilate_kernel_size=15):
    """Inpaint the masked region(s) of an image with a loaded LaMa model.

    The per-object masks stored in ``mask_file`` are merged into a single
    binary mask (union along axis 0), optionally dilated, saved together with
    an overlay visualization, and finally used to inpaint ``img_file``.
    Existing output files are never overwritten, so the function can be
    re-run to resume an interrupted job.

    Args:
        model: LaMa model as returned by ``load_lama_model``.
        predict_config: Prediction config as returned by ``load_lama_model``.
        img_file: Path of the image to inpaint.
        mask_file: Path of a ``.npy`` file holding the object masks stacked
            along axis 0.
            NOTE(review): presumably shaped (N, 1, H, W) or (N, H, W) —
            confirm against the mask producer.
        output_filename: Base name (without extension) of the output files.
        output_dir: Directory receiving ``<name>_mask_<i>.png``,
            ``<name>_w_mask_<i>.png`` and ``<name>_remove_<i>.png``.
        dilate_kernel_size: Kernel size passed to ``dilate_mask``; ``None``
            disables dilation. When the mask file holds more than one
            object, the value is overridden with 20 (also when ``None``).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    img = load_img_to_array(img_file)

    stacked = np.load(mask_file)
    num_objects = len(stacked)
    # Union of all object masks as a 0/255 uint8 image. The builtin `bool`
    # is used because the `np.bool` alias was removed in NumPy 1.24.
    merged = stacked.sum(axis=0).astype(bool).astype(np.uint8) * 255
    if merged.ndim == 2:
        # A (N, H, W) stack collapses to a single (H, W) mask; add a leading
        # axis so the loops below iterate over masks rather than pixel rows.
        merged = merged[np.newaxis]
    masks = list(merged)

    # dilate mask to avoid unmasked edge effect
    if num_objects > 1:
        dilate_kernel_size = 20
    if dilate_kernel_size is not None:
        masks = [dilate_mask(mask, dilate_kernel_size) for mask in masks]

    # visualize the segmentation results
    for idx, mask in enumerate(masks):
        # path to the results
        mask_p = os.path.join(output_dir, f"{output_filename}_mask_{idx}.png")
        img_mask_p = os.path.join(output_dir, f"{output_filename}_w_mask_{idx}.png")
        if not os.path.exists(mask_p):
            # save the mask
            save_array_to_img(mask, mask_p)

        if not os.path.exists(img_mask_p):
            # save the image with the mask overlaid
            dpi = plt.rcParams['figure.dpi']
            height, width = img.shape[:2]
            fig = plt.figure(figsize=(width/dpi/0.77, height/dpi/0.77))
            try:
                plt.imshow(img)
                plt.axis('off')
                show_mask(plt.gca(), mask, random_color=False)
                plt.savefig(img_mask_p, bbox_inches='tight', pad_inches=0)
            finally:
                # Always release the figure, even if drawing/saving fails.
                plt.close(fig)

    # inpaint the masked image
    for idx, mask in enumerate(masks):
        img_inpainted_p = os.path.join(output_dir, f"{output_filename}_remove_{idx}.png")
        if not os.path.exists(img_inpainted_p):
            img_inpainted = inpaint_img_with_lama_loaded(
                model, predict_config, img, mask, device=device)
            save_array_to_img(img_inpainted, img_inpainted_p)


if __name__ == '__main__':
    # Command-line entry point: Python Fire builds the CLI from `main`'s
    # signature, so its keyword parameters become flags, e.g.
    #   --annotation_file=... --image_folder=... --mask_folder=... --debug
    import fire

    fire.Fire(main)
